"""
The code is released exclusively for review purposes with the following terms:
PROPRIETARY AND CONFIDENTIAL. UNAUTHORIZED USE, COPYING, OR DISTRIBUTION OF THE 
CODE, VIA ANY MEDIUM, IS STRICTLY PROHIBITED. BY ACCESSING THE CODE, THE 
REVIEWERS AGREE TO DELETE THEM FROM ALL MEDIA AFTER THE REVIEW PERIOD IS OVER.
"""

""" Compute black-box model predictions for the sampled perturbations in the neighborhood """
import numpy as np
import sys
sys.path.append("../utilities/")
import os

from joblib import Parallel, delayed
# from sklearn.utils import check_random_state
import yaml
import pickle
import joblib

from utils import (fname_model, fname_base_perts, fname_preds, 
                    create_dir_if_not_exist)

# Parse the command-line keys that select the config file, dataset, model and
# perturbation set. All four are required: every one of them is used below to
# build file paths, so a missing value would otherwise surface later as an
# opaque TypeError from os.path.join(None, ...).
import argparse
parser = argparse.ArgumentParser(
    description="Compute black-box predictions on sampled perturbations.")
parser.add_argument("--config_fname", required=True,
                    help="YAML config file name inside the config/ directory")
parser.add_argument("--dataset_key", required=True,
                    help="dataset key; selects data/<dataset_key>/ subdirectories")
parser.add_argument("--model_key", required=True,
                    help="key of the model entry under Bb_Model in the config")
parser.add_argument("--pert_key", required=True,
                    help="key of the perturbation set to load")
args = parser.parse_args()

# Load the experiment configuration from config/<config_fname>.
# The context manager closes the file handle as soon as parsing finishes.
with open(os.path.join("config", args.config_fname)) as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)
# config['MeLime_Perturbations']["cnt"]=100
# config['Base_Perturbations']["cnt"]=100

# Load the base perturbations from data/<dataset_key>/perturbations/ and keep
# only the "samp_perts_bb" entry, which is the sequence scored below.
# NOTE: pickle can execute arbitrary code on load -- only load trusted files.
base_pert_fname = fname_base_perts(config,
                    args.pert_key, args.dataset_key) + ".pkl"
dirname = os.path.join("data", args.dataset_key, "perturbations")
with open(os.path.join(dirname, base_pert_fname), "rb") as pert_file:
    perturbations = pickle.load(pert_file)
samp_perts_bb = perturbations["samp_perts_bb"]

# Load the trained black-box model from data/<dataset_key>/models/.
# The file handle is closed by the context manager once joblib has read it.
modelfname = fname_model(config, args.model_key,
                        args.dataset_key) + ".pkl"
dirname = os.path.join("data", args.dataset_key, "models")
with open(os.path.join(dirname, modelfname), "rb") as model_file:
    bb_model = joblib.load(model_file)

# Echo the joblib worker count configured for this model, then record how
# many perturbation sets will be scored.
model_n_jobs = config["Bb_Model"][args.model_key]["n_jobs"]
print(model_n_jobs)
n_data_all = len(samp_perts_bb)

# Prediction target read once from the config: the string "reg" selects the
# model's `predict` output; any other value is used as a column index into
# the `predict_proba` output.
pred_target = config["Preds"]["cls"]


def predict_fn(x):
    """Score one task tuple ``x`` whose item 0 is a perturbation set and item 1 its index.

    Prints the index, then returns ``bb_model.predict`` on the perturbation
    set when the configured target is "reg". Otherwise it returns the
    flattened ``predict_proba`` column selected by the target.
    """
    samp_pert, task_idx = x[0], x[1]
    print(task_idx)
    if pred_target == "reg":
        return bb_model.predict(samp_pert)
    return bb_model.predict_proba(samp_pert)[:, pred_target].ravel()

# Score every perturbation set with joblib, using the `n_jobs` value from this
# model's config entry. Each task receives (perturbation_set, index).
# Parallel already returns a list in input order, so no extra list() copy is
# needed.
y_pred_perts = Parallel(n_jobs=config["Bb_Model"][args.model_key]["n_jobs"])(
                delayed(predict_fn)(x)
                for x in zip(samp_perts_bb, range(n_data_all)))

# Save the list of per-perturbation predictions under
# data/<dataset_key>/predictions/, creating the directory if needed.
preds_fname = fname_preds(config, args.pert_key, args.model_key, args.dataset_key) + ".pkl"
dirname = os.path.join("data", args.dataset_key, "predictions")
create_dir_if_not_exist(dirname)
# Closing the handle deterministically guarantees the pickle is flushed to disk.
with open(os.path.join(dirname, preds_fname), "wb") as preds_file:
    pickle.dump(y_pred_perts, preds_file)
